//
// Created by enemy1205 on 2021/10/6.
//
#include "HoughDetect.h"

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

LineDetect::LineDetect(Mat &img) {
    this->src = img.clone();//保存一份完整源图
    cv::GaussianBlur(img, img, cv::Size(3, 3), 0);
    cv::Canny(img, bin, 50, 150);
    this->threshold = 0;
}

LineDetect::LineDetect(Mat &img, int &thr) {
    this->src = img.clone();//保存一份完整源图
    cv::GaussianBlur(img, img, cv::Size(3, 3), 0);
    cv::Canny(img, bin, 50, 150);
    this->threshold = thr;
}

void LineDetect::calCounter(const int &thetaDim, const int &distStep) {
    maxDist = sqrt(bin.rows * bin.rows + bin.cols * bin.cols);
    this->thetaDim = thetaDim;
    this->distDim = ceilf(maxDist / (float) distStep);
    this->counter = Mat::zeros(thetaDim, distDim, CV_8UC1);
    this->vote();
    this->halfDistWindowSize = static_cast<int>(distDim / 50);
    double maxPolls = 0;
    cv::minMaxLoc(counter, nullptr, &maxPolls);
    if (threshold == 0) threshold = static_cast<int>(maxPolls * 2.3875 / 10);
}

/**
 * @brief
 * @param thetaStep
 */
void LineDetect::vote() {
    for (int i = 0; i < bin.rows; ++i) {
        for (int j = 0; j < bin.cols; ++j) {
            //对所有像素值不为0的点进行投票
            if (bin.ptr<uchar>(i)[j] != 0) {
                for (int k = 0; k < thetaDim; ++k) {
                    this->counter.ptr<uchar>(k)[int(
                            round((i * cos(k * CV_PI / thetaDim) + j * sin(k * CV_PI / thetaDim)) *
                                  (distDim / maxDist)))] += 1;
                }
            }
        }
    }
}

/**
 * @brief 非极大值抑制
 */
void LineDetect::filter() {
    std::vector<cv::Point> temp;
    for (int i = 0; i < counter.rows; ++i) {
        for (int j = 0; j < counter.cols; ++j) {
            if (counter.ptr<uchar>(i)[j] > threshold) {
                int maxVal = counter.ptr<uchar>(i)[j];
                Point maxPt(i, j);
                //遍历周围8点
                //左上
                int x = MAX(0, i - halfDistWindowSize + 1);
                int y = MAX(0, j - halfDistWindowSize + 1);
                if (counter.ptr<uchar>(x)[y] > maxVal) {
                    maxVal = counter.ptr<uchar>(x)[y];
                    maxPt = Point(x, y);
                }
                //上
                x = i;
                y = MAX(0, j - halfDistWindowSize + 1);
                if (counter.ptr<uchar>(x)[y] > maxVal) {
                    maxVal = counter.ptr<uchar>(x)[y];
                    maxPt = Point(x, y);
                }
                //右上
                x = MIN(counter.cols, i + halfDistWindowSize);
                y = MAX(0, j - halfDistWindowSize + 1);
                if (counter.ptr<uchar>(x)[y] > maxVal) {
                    maxVal = counter.ptr<uchar>(x)[y];
                    maxPt = Point(x, y);
                }
                //左
                x = MAX(0, i - halfDistWindowSize + 1);
                y = j;
                if (counter.ptr<uchar>(x)[y] > maxVal) {
                    maxVal = counter.ptr<uchar>(x)[y];
                    maxPt = Point(x, y);
                }
                //右
                x = MIN(counter.cols, i + halfDistWindowSize);
                y = j;
                if (counter.ptr<uchar>(x)[y] > maxVal) {
                    maxVal = counter.ptr<uchar>(x)[y];
                    maxPt = Point(x, y);
                }
                //左下
                x = MAX(0, i - halfDistWindowSize + 1);
                y = MIN(counter.rows, j + halfDistWindowSize);
                if (counter.ptr<uchar>(x)[y] > maxVal) {
                    maxVal = counter.ptr<uchar>(x)[y];
                    maxPt = Point(x, y);
                }
                //下
                x = i;
                y = MIN(counter.rows, j + halfDistWindowSize);
                if (counter.ptr<uchar>(x)[y] > maxVal) {
                    maxVal = counter.ptr<uchar>(x)[y];
                    maxPt = Point(x, y);
                }
                //右下
                x = MAX(0, i + halfDistWindowSize);
                y = MIN(counter.rows, j + halfDistWindowSize);
                if (counter.ptr<uchar>(x)[y] > maxVal) {
                    maxVal = counter.ptr<uchar>(x)[y];
                    maxPt = Point(x, y);
                }
                maxPt.x = maxPt.x * thetaDim / CV_PI;
                maxPt.y = maxPt.y * maxDist / distDim;
                temp.emplace_back(maxPt);
            }
        }
    }
    for (auto it = temp.begin(); it != temp.end(); it++) {
        if (it != temp.begin() && (*it) == *(it - 1)) continue;
        else this->potentialPoints.emplace_back(*it);
    }
    temp.clear();
    temp.shrink_to_fit();
}

/**
 * @brief 绘制
 */
void LineDetect::draw() {
    for (int i = 0; i < src.rows; ++i) {
        for (int j = 0; j < src.cols; ++j) {
            for (const auto pt: potentialPoints) {
                if (abs(pt.y - i * cos(pt.x) - j * sin(pt.x)) < 2)
                    src.ptr<cv::Vec3b>(i)[j] = {255, 0, 0};
            }
        }
    }
}

void LineDetect::run(const int &thetaDim, const int &distStep) {
//    std::vector<cv::Vec2f> lines; // will hold the results of the detection
//    HoughLines(bin, lines, 1, CV_PI / 180, 150, 0, 0); // runs the actual detection
//    // Draw the lines
//    for (size_t i = 0; i < lines.size(); i++) {
//        float rho = lines[i][0], theta = lines[i][1];
//        Point pt1, pt2;
//        double a = cos(theta), b = sin(theta);
//        double x0 = a * rho, y0 = b * rho;
//        pt1.x = cvRound(x0 + 1000 * (-b));
//        pt1.y = cvRound(y0 + 1000 * (a));
//        pt2.x = cvRound(x0 - 1000 * (-b));
//        pt2.y = cvRound(y0 - 1000 * (a));
//        line(src, pt1, pt2, cv::Scalar(0, 0, 255), 3, cv::LINE_AA);
//    }

    // Probabilistic Line Transform
//    std::vector<cv::Vec4i> linesP; // will hold the results of the detection
//    HoughLinesP(bin, linesP, 1, CV_PI/180, 50, 50, 10 ); // runs the actual detection
//    // Draw the lines
//    for( size_t i = 0; i < linesP.size(); i++ )
//    {
//        cv::Vec4i l = linesP[i];
//        line( src, Point(l[0], l[1]), Point(l[2], l[3]), cv::Scalar(0,0,255), 3, cv::LINE_AA);
//    }

    this->calCounter(thetaDim, distStep);
    this->filter();
    this->draw();
    cv::imshow("bin", bin);
    cv::imshow("dst", src);
    cv::waitKey(0);
}

/**
 * @brief Placeholder: CircleDetect voting is not implemented yet (the body is empty).
 */
void CircleDetect::vote() {}